[MXNET-1426] Fix the wrong result of sum, mean, argmin, argmax when inputs contain inf or nan #16234
The Julia part looks fine to me.
Hi @marcoabreu @access2rohit, could you please help review this PR?
Hi @reminisce and @haojin2, could you please help review this PR? It makes the following functions consistent with NumPy.
Thank you!
Hi @eric-haibin-lin, could you please help review this PR? Thank you so much!
Would you mind also adding the result before this fix?
Hi @eric-haibin-lin, I have updated the test result : )
cc @reminisce
Ping : )
3rdparty/mshadow/mshadow/base.h
Outdated
}
template<>
MSHADOW_XINLINE bool IsNan(volatile mshadow::half::half_t val) {
  return (val.half_ & 0x7fff) > 0x7c00;
Can you turn these magic values into constants with documentation? While I get 0x7fff, 0x7c00, for example, looks quite arbitrary.
Hi @marcoabreu, I added two constants, MSHADOW_HALF_SIGN_BIT and MSHADOW_HALF_EXPONENT_BITS, in 3rdparty/mshadow/mshadow/half.h, and replaced these two magic values with them.
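For readers unfamiliar with the bit pattern in the snippet under review, here is a minimal Python sketch of the same check on IEEE 754 half-precision values. The helper names (`half_bits`, `is_nan_half`, `is_inf_half`) and the Python constant names are assumptions made for illustration and are not part of mshadow; only the masks 0x7fff and 0x7c00 come from the reviewed code.

```python
import numpy as np

# IEEE 754 binary16 layout: 1 sign bit, 5 exponent bits, 10 mantissa bits.
# The names below are hypothetical and exist only for this sketch.
HALF_SIGN_BIT = 0x8000       # the top bit
HALF_EXPONENT_BITS = 0x7C00  # the five exponent bits, all set
HALF_ABS_MASK = 0x7FFF       # every bit except the sign


def half_bits(value):
    """Return the raw 16-bit pattern of `value` stored as float16."""
    return int(np.array([value], dtype=np.float16).view(np.uint16)[0])


def is_nan_half(bits):
    # NaN: exponent bits all ones and a non-zero mantissa, so the
    # sign-stripped pattern is strictly greater than 0x7c00.
    return (bits & HALF_ABS_MASK) > HALF_EXPONENT_BITS


def is_inf_half(bits):
    # Infinity: exponent bits all ones and a zero mantissa.
    return (bits & HALF_ABS_MASK) == HALF_EXPONENT_BITS
```

Stripping the sign bit first means positive and negative infinities, and NaNs of either sign, go through the same comparison.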
if type(ndarray_ret) is mx.ndarray.NDArray:
    ndarray_ret = ndarray_ret.asnumpy()
assert (ndarray_ret.shape == numpy_ret.shape) or \
    (ndarray_ret.shape == (1,) and numpy_ret.shape == ()), "nd:%s, numpy:%s" \
    %(ndarray_ret.shape, numpy_ret.shape)
err = np.square(ndarray_ret - numpy_ret).mean()
assert err < 1E-4
if check_dtype:
Could you elaborate on why you're introducing so much branching into a test? If the results are inconsistent, we should improve the test rather than skip the checks. I'd love to have more detail.
Hi @marcoabreu, here is the explanation.
- So much branching
  We need to test all reduce operators, like min, max, argmin, argmax, sum, mean, when the inputs contain -inf, +inf, nan.
- Skipping the checks
  I replaced the old check with a new one. : )
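As a rough illustration of the cases mentioned in this reply, the sketch below records how NumPy itself treats +inf and nan in these reductions. It uses NumPy only and does not call MXNet; the helper name `numpy_reference` is made up for this example.

```python
import warnings

import numpy as np


def numpy_reference(values):
    """Reduce a 1-D list with NumPy and return the results as a dict."""
    data = np.array(values, dtype=np.float32)
    with warnings.catch_warnings():
        # Reductions over inf/nan may emit RuntimeWarning; silence it here.
        warnings.simplefilter("ignore", RuntimeWarning)
        return {
            "sum": data.sum(),
            "mean": data.mean(),
            "min": data.min(),
            "max": data.max(),
            "argmin": int(data.argmin()),
            "argmax": int(data.argmax()),
        }


inf_case = numpy_reference([np.inf, np.inf])      # sum and mean stay +inf
nan_case = numpy_reference([1.0, np.nan, 3.0])    # nan propagates everywhere
```

In NumPy, `argmin` and `argmax` both return the index of the first nan when one is present. When comparing against such a reference, note that `np.testing.assert_allclose` treats nan as equal to nan by default (`equal_nan=True`), whereas a mean-squared-error threshold evaluates to nan as soon as either side contains inf or nan.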
I'll merge after the feedback has been addressed :) Sorry for the delay
Description
Hi there,
I fixed the wrong result of sum(inf, inf) and mean(inf, inf).
Test Case:
The wrong result in mxnet_mkl-1.6.0b20191015-py2.py3-none-manylinux1_x86_64:
The correct result in this PR:
If we modify the test function, here is the result of NumPy:
Changes
- isnan_typed and isinf_typed in mshadow
- isnan_typed in src/operator/mshadow_op.h
- mshadow/extension/reduce_with_axis.h, to support NaN in argmin and argmax.
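One way to read the last item: a plain `>` comparison never selects NaN, so an argmax loop built on it skips NaN entries, whereas NumPy returns the index of the first NaN. The sketch below contrasts the two in plain Python. Both `naive_argmax` and `nan_aware_argmax` are hypothetical helpers written for illustration, assume a non-empty input, and are not the code in reduce_with_axis.h.

```python
import math


def naive_argmax(values):
    """Argmax using only `>`; NaN never compares greater, so it is skipped."""
    best = 0
    for i in range(1, len(values)):
        if values[i] > values[best]:
            best = i
    return best


def nan_aware_argmax(values):
    """Argmax that returns the index of the first NaN, as NumPy does."""
    best = 0
    for i in range(1, len(values)):
        if math.isnan(values[best]):
            break  # the first NaN wins; nothing later can replace it
        if math.isnan(values[i]) or values[i] > values[best]:
            best = i
    return best


sample = [1.0, float("nan"), 3.0]
naive_index = naive_argmax(sample)      # ignores the NaN entry
aware_index = nan_aware_argmax(sample)  # stops at the NaN entry
```

An argmin variant would follow the same pattern with `<` in place of `>`.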